#include "IWRetinex.h"

#include <limits.h>
#include <math.h>
#include <string.h>


/*
 * Entry point: runs Multi_Scale_Retinex on the pixel buffer of ori_image.
 *
 * ori_image  : its data buffer (as BYTE*), width and height are forwarded;
 *              Multi_Scale_Retinex writes its result back into that buffer.
 * ret_image  : not referenced here.
 *              NOTE(review): presumably intended to receive the result - confirm.
 * image_byte : bytes per pixel, forwarded unchanged.
 */
void Retinex(ip_Image* ori_image, ip_Image* ret_image, int image_byte)
{
	BYTE *pixels = (BYTE *) ori_image->data;
	const int w = ori_image->width;
	const int h = ori_image->height;

	(void) ret_image;
	Multi_Scale_Retinex(pixels, w, h, image_byte);
}

/*
 * Fill scales[0..nscales-1] with the Gaussian sigmas used by the
 * multi-scale filter.
 *
 * scales  : output array; must hold at least nscales entries.
 * nscales : number of scales. Nothing is written when it is <= 0.
 *           1 -> { s/2 }, 2 -> { s/2, s }; otherwise `mode` decides.
 * mode    : RETINEX_UNIFORM, RETINEX_LOW or RETINEX_HIGH; any other value
 *           uses the uniform distribution.
 * s       : largest scale.
 */
static void
retinex_scales_distribution(float* scales, int nscales, int mode, int s)
{
	float size_step;
	int   i;

	if (nscales <= 0)
		return;		/* nothing to fill, and avoids dividing by zero below */

	if (nscales == 1) {
		/* For one filter we choose the median scale. */
		scales[0] = (float)(s / 2);
		return;
	}

	if (nscales == 2) {
		/* For two filters we choose the median and maximum scale. */
		scales[0] = (float)(s / 2);
		scales[1] = (float) s;
		return;
	}

	switch (mode) {
	case RETINEX_LOW:
		/*
		 * Logarithmic steps concentrated on the small scales.  The step must
		 * stay fractional: truncating ln(s - 2) / nscales to an int collapses
		 * it (e.g. ln(238) / 3 = 1.82 -> 1).
		 */
		size_step = (float) math_logn(s - 2) / (float) nscales;
		for (i = 0; i < nscales; ++i)
			scales[i] = (float)(2 + pow (10, (i * size_step) / math_logn(10)));
		break;

	case RETINEX_HIGH:
		/* Same logarithmic steps, mirrored towards the large scales. */
		size_step = (float) math_logn(s - 2) / (float) nscales;
		for (i = 0; i < nscales; ++i)
			scales[i] = (float)(s - pow (10, (i * size_step) / math_logn(10)));
		break;

	case RETINEX_UNIFORM:
	default:
		/* Evenly spaced scales starting at 2. */
		size_step = (float) s / (float) nscales;
		for (i = 0; i < nscales; ++i)
			scales[i] = 2.0f + (float) i * size_step;
		break;
	}
}

/*
 * Compute the coefficients of the third-order recursive Gaussian filter
 * (Young & van Vliet, "Recursive implementation of the Gaussian filter")
 * for a Gaussian of standard deviation sigma.
 *
 * c     : receives b[0..3], the normalisation B = 1 - (b1 + b2 + b3) / b0,
 *         sigma and the filter order n = 3.
 * sigma : standard deviation, in samples.
 */
static void
compute_coefs3 (gauss3_coefs *c, float sigma)
{
	double q, q2, q3;

	if (sigma >= 2.5) {
		q = 0.98711 * sigma - 0.96330;
	} else if (sigma >= 0.5) {
		/*
		 * Note the square root: with it, q at sigma = 0.5 (~0.11477) meets the
		 * constant used for smaller sigmas below.  A plain "/ 2" gives ~2.18
		 * there and makes q jump at both ends of this interval.
		 */
		q = 3.97156 - 4.14554 * sqrt (1.0 - 0.26891 * sigma);
	} else {
		q = 0.1147705018520355224609375;
	}

	q2 = q * q;
	q3 = q * q2;

	c->b[0] = (1.57825+(2.44413*q)+(1.4281 *q2)+(0.422205*q3));
	c->b[1] = (        (2.44413*q)+(2.85619*q2)+(1.26661 *q3));
	c->b[2] = (                   -((1.4281*q2)+(1.26661 *q3)));
	c->b[3] = (                                 (0.422205*q3));
	c->B = 1.0-((c->b[1]+c->b[2]+c->b[3])/c->b[0]);
	c->sigma = sigma;
	c->n = 3;
}

/*
 * One-dimensional recursive Gaussian smoothing of a single line.
 *
 * Papers:  "Recursive Implementation of the gaussian filter.",
 *          I. T. Young, L. J. van Vliet, Signal Processing 44, 1995.
 * formula: 9a        forward filter
 *          9b        backward filter
 *
 * in        : sample i is read from in[i * rowstride].
 * out       : sample i is written to out[i * rowstride].
 * size      : number of samples on the line; nothing is done when <= 0.
 * rowstride : distance between consecutive samples (1 for a row, the image
 *             width for a column).
 * c         : coefficients produced by compute_coefs3().
 *
 * The forward pass starts from three copies of the first input value, the
 * backward pass from three copies of the last forward-pass value.  If the
 * scratch buffers cannot be allocated, out is left untouched.
 */
static void
gausssmooth (double *in, double *out, int size, int rowstride, gauss3_coefs *c)
{
	int     i, n, last;
	double *w1, *w2;

	if (size <= 0)
		return;

	/* Three extra slots hold the boundary history of each pass. */
	w1 = (double *) mem_malloc ((size + 3) * sizeof (double));
	w2 = (double *) mem_malloc ((size + 3) * sizeof (double));
	if (w1 == NULL || w2 == NULL) {
		if (w1 != NULL)
			mem_free (w1);
		if (w2 != NULL)
			mem_free (w2);
		return;
	}

	last = size - 1;

	/* Forward (causal) pass: w1[i + 3] holds the result for sample i. */
	w1[0] = in[0];
	w1[1] = in[0];
	w1[2] = in[0];
	for (i = 0, n = 3; i <= last; i++, n++) {
		w1[n] = c->B * in[i * rowstride] +
			((c->b[1] * w1[n-1] +
			  c->b[2] * w1[n-2] +
			  c->b[3] * w1[n-3]) / c->b[0]);
	}

	/* Backward (anti-causal) pass: w2[i] holds the final result for sample i. */
	w2[last + 1] = w1[last + 3];
	w2[last + 2] = w1[last + 3];
	w2[last + 3] = w1[last + 3];
	for (i = last; i >= 0; i--) {
		/*
		 * The forward result for sample i lives at w1[i + 3]; reading w1[i]
		 * would shift the filtered line by three samples.
		 */
		w2[i] = c->B * w1[i + 3] +
			((c->b[1] * w2[i+1] +
			  c->b[2] * w2[i+2] +
			  c->b[3] * w2[i+3]) / c->b[0]);
		out[i * rowstride] = w2[i];
	}

	mem_free (w1);
	mem_free (w2);
}

 
/*
 * Calculate the mean and the (population) variance in one pass over the
 * first three values of every pixel of an interleaved buffer.
 *
 * src   : `size` interleaved values, `bytes` values per pixel; only the
 *         first three values of each pixel are sampled (alpha is ignored).
 * mean  : receives the mean of the sampled values.
 * var   : receives E[x^2] - E[x]^2 of the sampled values (never negative).
 * size  : total number of values in src.
 * bytes : values per pixel, e.g. 3 for RGB or 4 for RGBA; must be >= 3.
 *
 * Both outputs are 0 when there is nothing to sample.
 */
static void
compute_mean_var (double *src, double *mean, double *var, int size, int bytes) {
	double sum    = 0.0;
	double sum_sq = 0.0;
	long   count  = 0;
	int    i, j;

	*mean = 0.0;
	*var  = 0.0;
	if (src == NULL || size <= 0 || bytes < 3)
		return;

	/*
	 * Accumulate in double: the samples are fractional, and integer sums of
	 * squares overflow on large images.  The bound keeps psrc[2] in range.
	 */
	for (i = 0; i <= size - 3; i += bytes) {
		const double *psrc = src + i;

		for (j = 0; j < 3; j++) {
			sum    += psrc[j];
			sum_sq += psrc[j] * psrc[j];
		}
		count += 3;
	}

	if (count == 0)
		return;

	/* Divide by the number of samples actually summed, not by `size`. */
	*mean = sum / (double) count;
	*var  = sum_sq / (double) count - *mean * *mean;
	if (*var < 0.0)
		*var = 0.0;	/* rounding can push a tiny variance below zero */
}



/*
* This function is the heart of the algo.
* (a)  Filterings at several scales and sumarize the results.
* (b)  Calculation of the final values.
*/
void
Multi_Scale_Retinex (BYTE *src, int width, int height, int bytes/* =3 for RGB, 4=RGBA*/) {
    
	/* Multi-scale Retinex Color Correction */
    
	
	int           scale,row,col;
    int           weight;
	
    gauss3_coefs  coef;
	
    double        mean;
	double		  var;
	
    int		        mini;
	
	int				range;
	int				  maxi;
	
	int				alpha = 255;
	
	int			  size;
	int			  channelsize;
	
	double *	  dst;
	double *	  in;
	double *	  out;
	
	int channel;
	
	int i; 
	int pos; 
    
	int j;
	int c, ic;
	//* Allocate all the memory needed for algorithm
	int logl;
				
//	BYTE *psrc ;
//	double *pdst;
	
	
	
	size = width * height * bytes ; // / 4 ; <-----/4
	
//	dst = (double*) mem_malloc (size * sizeof (double));
  //  memset(dst, 0, sizeof(double) * size) ;
	
	
    channelsize  = (width * height);
	in  = (double *) mem_malloc (channelsize * sizeof (double));
    memset(in, 0, sizeof(double) * channelsize) ;
    
    out  = (double *) mem_malloc (channelsize * sizeof (double));
    memset(out, 0, sizeof(double) * channelsize) ;
	
    //* 
    //Calculate the scales of filtering according to the
    //number of filter and their distribution.
	debug_set_cur_time();
    retinex_scales_distribution (RetinexScales,
		rvals.nscales, rvals.scales_mode, rvals.scale);
	debug_print_proc_time(" retinex_scales_distribution");
    /*
    Filtering according to the various scales.
    Summerize the results of the various filters according to a
    specific weight(here equivalent for all).
	*/
    
	weight = 1 /  rvals.nscales;
    
	/*
    The recursive filtering algorithm needs different coefficients according
    to the selected scale (~ = standard deviation of Gaussian).
    */ 
	
    for ( channel = 0; channel < 3; channel++) {
		debug_set_cur_time(); 
		
        for ( i = 0, pos = channel; i < channelsize ; i++, pos += bytes) {
            //* 0-255 => 1-256 
			
            in[i] = ((int)src[pos] + 1);
		} 
		debug_print_proc_time(" channel size ");
		
		
		debug_set_cur_time();
        for (scale = 0; scale < rvals.nscales; scale++) {
		
            compute_coefs3 (&coef, RetinexScales[scale]);
            
			//*
            //*  Filtering (smoothing) Gaussian recursive.
            //*  Filter rows first
            //* 
			
            for ( row=0 ;row < height; row++) {
		        pos =  row * width;
				gausssmooth (in + pos, out + pos, width, 1, &coef);
            }
			
            memcpy(in,  out, channelsize * sizeof(double));
            memset(out, 0  , channelsize * sizeof(double));
            
			
			//*  Filtering (smoothing) Gaussian recursive.
			// *  Second columns
			
            for ( col=0; col < width; col++) {
                pos = col;
				gausssmooth(in + pos, out + pos, height, width, &coef);                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                     
                                    
            }
			
			// Summarize the filtered values.
			// In fact one calculates a ratio between the original values and the filtered values.
			
		}
		debug_print_proc_time("\n\n &&Nscales&& ");
		
	}
	
	mem_free(in);
	//	mem_free(out);
	
	
    dst = (double*) mem_malloc (size * sizeof (double));
    memset(dst, 0, sizeof(double) * size) ;
	debug_set_cur_time();
	for ( channel = 0; channel < 3; channel++) {
		for (scale = 0; scale < rvals.nscales; scale++) {
			for (i = 0, pos = channel; i < channelsize ; i++, pos += bytes ) {
			if (src[pos]<240)
				dst[pos] += weight * (math_logf( (float)(src[pos] + 1) ) - (math_logf((float)out[i])));
			}  
		} 
	} 
	debug_print_proc_time("2/4 loop");
	mem_free(out);
	
    /*
    Final calculation with original value and cumulated filter values.
    The parameters gain, alpha and offset are constants.
	*/ 
     
	debug_set_cur_time();
	
	for ( i = 0; i < size; i += bytes ) { 
		
		BYTE *psrc = src+i;
		double *pdst = dst+i;
		
		if (psrc[0]<230 && psrc[1]<230 && psrc[2]<230) 
		{ 
			logl = (int)math_logf((float)(psrc[0] + psrc[1] + psrc[2] + 3));
			pdst[0] = ((math_logf((float)(alpha * (psrc[0]+1))) - logl) * pdst[0]);
			pdst[1] = ((math_logf((float)(alpha * (psrc[1]+1))) - logl) * pdst[1]);
			pdst[2] = ((math_logf((float)(alpha * (psrc[2]+1))) - logl) * pdst[2]);
		
		}
	} 
	
	debug_print_proc_time("3/4 loop");
    ///*
	// ** Adapt the dynamics of the colors according to the statistics of the first and second order.
	// ** The use of the variance makes it possible to control the degree of saturation of the colors.
	 
	debug_set_cur_time();
    compute_mean_var (dst, &mean, &var, size, bytes);
	debug_print_proc_time("compute_mean_var");
    mini = (int)( mean - rvals.cvar*(int)var );
    maxi = (int)( mean + rvals.cvar*(int)var );
    range = maxi - mini;
    
    if (!range)
        range = 1; 
	 
	debug_set_cur_time();
	
    for (i = 0; i < size; i+= bytes  ) {
		 
        BYTE *psrc = src + i;
        double *pdst = dst + i;
        for ( j = 0 ; j < 3 ; j++) {
			c = (int) (255 * ( pdst[j] - mini ) / range);
			ic = c  ;
			psrc[j] = ic<0? 0 : ic>255 ? 255 : ic;
        }
    } 
	 
	debug_print_proc_time("4/4 loop");

	mem_free (dst); 
}
 